import time
import numpy as np

import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.utils.plot import Ploter

from source.data import multip_thread_reader
from source.model import YOLOv3
from source.loss import get_sum_loss

num_classes = 7                                                                              # 类别数量
anchor_size = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] # 锚框大小
anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]                                              # 锚框掩码
ignore_threshold = 0.7                                                                       # 忽略阈值
downsample_ratio = 32                                                                        # 下采样率

def train(epoch_numb=50,
          learning_rate=[0.001, 0.0005, 0.00025, 0.00002, 0.00001],
          lr_iterations=[4250, 8500, 12750, 17000],
          model_path='./output/darknet53-yolov3',
          train_path='./dataset/train/',
          valid_path='./dataset/val/'):
    """
    功能:
        训练模型
    输入:
        epoch_numb    - 迭代周期
        learning_rate - 学习率
        lr_iterations - 学习率迭代数
        model_path    - 模型保存路径
        train_path    - 训练数据路径
        valid_path    - 验证数据路径
    输出:
    """
    with fluid.dygraph.guard(): # 动态图域
        # 准备数据
        train_reader = multip_thread_reader(train_path, 4, 'train') # 读取训练数据
        valid_reader = multip_thread_reader(valid_path, 2, 'valid') # 读取验证数据

        # 声明模型
        model = YOLOv3(num_classes=num_classes, anchor_mask=anchor_mask)

        # 优化算法
        optimizer = fluid.optimizer.Momentum(
            learning_rate=fluid.layers.piecewise_decay(boundaries=lr_iterations, values=learning_rate),
            momentum=0.9,
            parameter_list=model.parameters())

        # 开始训练
        train_prompt = "Train loss"                 # 训练标签
        valid_prompt = "Valid loss"                 # 验证标签
        ploter = Ploter(train_prompt, valid_prompt) # 损失图像
        
        iterator = 1                                # 迭代次数
        iter_time = time.time()                     # 迭代时间
        best_epoch = 0                              # 最好周期
        best_loss = 100000.0                        # 最好损失
        
        for epoch_id in range(epoch_numb):
            # 训练模型
            model.train() # 设置训练
            for batch_id, train_data in enumerate(train_reader()):
                # 读取数据
                image, gtbox, gtcls, image_size = train_data # 读取一条数据
                image = to_variable(image)                   # 转换数据格式

                # 前向传播
                infer = model(image)

                # 计算损失
                loss = get_sum_loss(
                    infer, gtbox, gtcls, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio)

                # 反向传播
                loss.backward()          # 反向传播
                optimizer.minimize(loss) # 更新权重
                model.clear_gradients()  # 清除梯度
                
                # 显示损失
                if iterator % 10 == 0:                                  # 显示损失次数
                    ploter.append(train_prompt, iterator, loss.numpy()) # 添加损失数值
                    ploter.plot()                                       # 显示损失图像
                    print("train - iter: {:5d}, epoch: {:3d}, loss: {[0]:.3f}, best loss:{:.3f}".format(
                        iterator, epoch_id, loss.numpy(), best_loss))
                iterator += 1 # 增加损失次数

            # 验证模型
            loss_list = [] # 损失列表
            model.eval()   # 设置验证
            for batch_id, valid_data in enumerate(valid_reader()):
                # 读取数据
                image, gtbox, gtcls, image_size = valid_data # 读取一条数据
                image = to_variable(image)                   # 转换数据格式

                # 前向传播
                infer = model(image)

                # 计算损失
                loss = get_sum_loss(
                    infer, gtbox, gtcls, num_classes, anchor_size, anchor_mask, ignore_threshold, downsample_ratio)

                # 记录损失
                loss_list.append(loss.numpy())

            # 显示损失
            mean_loss = np.mean(loss_list)                   # 计算验证损失
            ploter.append(valid_prompt, iterator, mean_loss) # 添加损失数值
            ploter.plot()                                    # 显示损失图像
            print("valid - iter: {:5d}, epoch: {:3d}, loss: {:.3f}".format(iterator, epoch_id, mean_loss))

            # 保存模型
            if mean_loss < best_loss:
                fluid.save_dygraph(model.state_dict(), model_path) # 保存模型
                best_loss = mean_loss                              # 更新损失
                best_epoch = epoch_id                              # 更新迭代

        # 显示时间
        iter_time = time.time() - iter_time # 总的时间
        print("best - epoch:{:4d}, loss:{:.3f}, time: {:.0f}s".format(best_epoch ,best_loss, iter_time))